import copy
import csv
import math
import random
import sys
import time

from matplotlib import pyplot as plt
from xlwt import Workbook


class Solution:
    """One GA individual: a demand visiting order plus its decoded routes and costs."""

    def __init__(self):
        # Chromosome and its decoding
        self.list_demand = []  # demand node ids in visiting order (the chromosome)
        self.list_route = []  # vehicle routes, each [depot id, demand ids..., depot id]
        self.list_timetable = []  # per route, a list of (arrival, departure) pairs
        # Scores
        self.value = 0.0  # objective value (distance or time cost, depending on opt_type)
        self.fitness = 0.0  # fitness used by selection
        self.cost_distance = 0.0  # total distance cost
        self.cost_time = 0.0  # total time cost


class Node:
    """A network node: a demand point or a depot, depending on which CSV it came from."""

    def __init__(self):
        self.id = 0  # node id
        self.x, self.y = 0, 0  # coordinates
        self.demand = 0  # demand quantity (demand points)
        self.capacity = 0  # 车容量 column for depots; decremented once per route dispatched
        # Service time window [time_start, time_end]; 1440 presumably = minutes in a day
        self.time_start, self.time_end = 0, 1440
        self.time_service = 0  # service duration spent at the node


class GA:
    """Genetic-algorithm parameters together with the problem data and population."""

    def __init__(self):
        # GA parameters
        self.pc = 0.5  # crossover probability
        self.pm = 0.1  # mutation probability
        self.size = 100  # population size
        self.n_select = 80  # number of individuals picked by tournament selection
        # Vehicle / objective settings
        self.vehicle_cap = 0  # vehicle capacity
        self.vehicle_speed = 1  # vehicle speed (travel time = ceil(distance / speed))
        self.opt_type = 1  # objective: 0 = minimize distance cost, 1 = minimize time cost
        # Search state
        self.best_solution = None  # global best, a Solution
        self.list_solution = []  # population, a list of Solution
        # Problem data
        self.dict_demand = {}  # demand node id -> Node
        self.list_demand = []  # demand node ids
        self.dict_depot = {}  # depot id -> Node
        self.list_depot = []  # depot ids
        self.distance_matrix = {}  # (from_id, to_id) -> distance
        self.time_matrix = {}  # (from_id, to_id) -> travel time


# noinspection PyTypeChecker
def get_data(demand_csv, supply_csv, model):
    """Load demand points and depots from two CSV files into *model*.

    Args:
        demand_csv: path of the demand-point CSV. Columns read: 编号 (id, parsed
            as int), 坐标x / 坐标y (coordinates), 需求量 (demand),
            开始时间窗 / 结束时间窗 (time-window start / end), 服务时间 (service time).
        supply_csv: path of the depot CSV. Columns read: 编号 (id, kept as the
            raw string), 坐标x / 坐标y, 车容量 (stored as ``capacity``),
            开始时间窗 / 结束时间窗.
        model: object whose ``dict_demand`` / ``list_demand`` / ``dict_depot`` /
            ``list_depot`` containers are filled in place.

    If either file does not exist, prints a notice, waits 10 s and exits the
    process with status 1. Other errors (missing column, unparsable number)
    propagate unchanged.
    """
    try:
        # newline='' is what the csv module requires so quoted fields containing
        # line breaks are parsed correctly.
        # NOTE(review): the platform default encoding is used — confirm it matches
        # the encoding the CSV files are saved in.
        with open(demand_csv, 'r', newline='') as f:
            for row in csv.DictReader(f):
                node = Node()
                node.id = int(row['编号'])
                node.x = float(row['坐标x'])
                node.y = float(row['坐标y'])
                node.demand = float(row['需求量'])
                node.time_start = float(row['开始时间窗'])
                node.time_end = float(row['结束时间窗'])
                node.time_service = float(row['服务时间'])
                model.dict_demand[node.id] = node
                model.list_demand.append(node.id)

        with open(supply_csv, 'r', newline='') as f:
            for row in csv.DictReader(f):
                node = Node()
                # Depot ids stay strings (plot_routes compares them with 'd1' / 'd2').
                node.id = row['编号']
                node.x = float(row['坐标x'])
                node.y = float(row['坐标y'])
                node.capacity = float(row['车容量'])
                node.time_start = float(row['开始时间窗'])
                node.time_end = float(row['结束时间窗'])
                model.dict_depot[node.id] = node
                model.list_depot.append(node.id)
    except FileNotFoundError:
        print("注意：没有找到数据文件！")
        time.sleep(10)
        sys.exit(1)  # non-zero status: a missing input file is a failure


def select_supply(route, depot_dict, model):
    """Close *route* at both ends with the nearest depot that still has vehicles.

    Among depots whose ``capacity`` is > 0, picks the one minimizing
    distance(depot -> route[0]) + distance(route[-1] -> depot). On ties, the
    first such depot in dict order wins. The chosen depot id is inserted at the
    front of *route* and appended at its end (in place), and that depot's
    ``capacity`` is decremented by one.

    Returns:
        (route, depot_dict): the same objects that were passed in, mutated.

    If no depot has capacity left, prints a notice, waits 10 s and exits the
    process with status 1.
    """
    first, last = route[0], route[-1]
    best_id = None
    best_distance = float('inf')
    for depot in depot_dict.values():
        if depot.capacity <= 0:
            continue  # no vehicles left at this depot
        round_trip = model.distance_matrix[depot.id, first] + model.distance_matrix[last, depot.id]
        if round_trip < best_distance:
            best_id, best_distance = depot.id, round_trip
    if best_id is None:
        print("注意：没有可调度的车辆！")
        time.sleep(10)
        sys.exit(1)  # non-zero status: running out of vehicles is a failure
    route.insert(0, best_id)
    route.append(best_id)
    depot_dict[best_id].capacity -= 1
    return route, depot_dict


def matrix_distance_time(model):
    """Fill ``model.distance_matrix`` and ``model.time_matrix`` in place.

    For every pair of distinct demand nodes, and for every demand node paired
    with every depot, stores the Euclidean distance in both directions. It also
    stores the travel time ``ceil(distance / model.vehicle_speed)`` in both
    directions. Depot-to-depot and node-to-itself entries are not computed.
    """
    speed = model.vehicle_speed
    dist_m = model.distance_matrix
    time_m = model.time_matrix
    ids = model.list_demand
    for pos, a_id in enumerate(ids):
        a = model.dict_demand[a_id]
        # Later demand nodes first, then all depots (same order as before).
        targets = [(b_id, model.dict_demand[b_id]) for b_id in ids[pos + 1:]]
        targets += [(depot.id, depot) for depot in model.dict_depot.values()]
        for b_id, b in targets:
            d = math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2)
            t = math.ceil(d / speed)
            dist_m[a_id, b_id] = dist_m[b_id, a_id] = d
            time_m[a_id, b_id] = time_m[b_id, a_id] = t


def calculate_cost(route_list, model):
    """Build per-route timetables and the total distance and time costs.

    Each route is ``[depot id, demand ids..., depot id]``. The vehicle leaves
    the depot just in time to reach the first demand node at its window start
    (never before time 0). At each demand node:

    - arrival = max(previous departure + travel, window start)
    - departure = arrival + service time

    Returns:
        (timetable_list, cost_of_time, cost_of_distance):

        - timetable_list: one list of (arrival, departure) pairs per route;
          depot entries are (t, t).
        - cost_of_distance: the sum of all leg distances.
        - cost_of_time: the sum of travel, service and waiting time.
    """
    timetable_list = []
    cost_of_distance = 0
    cost_of_time = 0
    dist = model.distance_matrix
    ttime = model.time_matrix
    for route in route_list:
        depot_id, first_id = route[0], route[1]
        start = max(0, model.dict_demand[first_id].time_start - ttime[depot_id, first_id])
        timetable = [(start, start)]
        # Legs ending at each demand node: (route[0], route[1]) ... (route[-3], route[-2]).
        for last_id, node_id in zip(route[:-2], route[1:-1]):
            node = model.dict_demand[node_id]
            travel = ttime[last_id, node_id]
            reach = timetable[-1][1] + travel
            # Waiting must be measured from the previous departure, i.e. before this
            # node's own entry is appended (it used to be computed afterwards and
            # therefore always came out as 0).
            wait = max(node.time_start - reach, 0)
            arrival = max(reach, node.time_start)
            departure = arrival + node.time_service
            timetable.append((arrival, departure))
            cost_of_distance += dist[last_id, node_id]
            cost_of_time += travel + node.time_service + wait
        # Final leg back to the closing depot.
        last_id, end_id = route[-2], route[-1]
        back = timetable[-1][1] + ttime[last_id, end_id]
        timetable.append((back, back))
        cost_of_distance += dist[last_id, end_id]
        cost_of_time += ttime[last_id, end_id]
        timetable_list.append(timetable)
    return timetable_list, cost_of_time, cost_of_distance


def calculate_fitness(model):
    """Evaluate the population, assign fitness and refresh the global best.

    For every solution in ``model.list_solution``:

    - split its demand order into routes (``split_routes``);
    - compute timetables and costs (``calculate_cost``) and store them on the
      solution;
    - set ``value`` to the distance cost (``opt_type == 0``) or the time cost
      (otherwise).

    Fitness is the population's largest value minus the solution's own value,
    so cheaper solutions score higher and the worst scores 0. If this
    generation's best value beats ``model.best_solution.value``, a deep copy
    replaces it.
    """
    worst_value = -float('inf')
    generation_best = Solution()
    generation_best.value = float('inf')
    for candidate in model.list_solution:
        _, routes = split_routes(copy.deepcopy(candidate.list_demand), model)
        timetables, time_cost, distance_cost = calculate_cost(routes, model)
        candidate.value = distance_cost if model.opt_type == 0 else time_cost
        candidate.list_route = routes
        candidate.list_timetable = timetables
        candidate.cost_distance = distance_cost
        candidate.cost_time = time_cost
        worst_value = max(worst_value, candidate.value)
        if candidate.value < generation_best.value:
            generation_best = copy.deepcopy(candidate)
    for candidate in model.list_solution:
        candidate.fitness = worst_value - candidate.value
    if generation_best.value < model.best_solution.value:
        model.best_solution = copy.deepcopy(generation_best)


def plan_routes(node_id_list, pre, model):
    """Turn the labels from ``split_routes`` into closed depot routes.

    Consecutive entries of *node_id_list* whose ``pre`` label is equal form one
    route. Each route is closed with a depot by ``select_supply``. That works
    on a deep copy of ``model.dict_depot``, so the model's own depot capacities
    are left untouched.

    Returns:
        list of routes, each ``[depot id, demand ids..., depot id]``.
    """
    depots_left = copy.deepcopy(model.dict_depot)
    routes = []
    current = [node_id_list[0]]
    current_label = pre[node_id_list[0]]
    for node_id in node_id_list[1:]:
        if pre[node_id] != current_label:
            # Label changed: close the running route and start a new one.
            closed, depots_left = select_supply(current, depots_left, model)
            routes.append(closed)
            current = []
            current_label = pre[node_id]
        current.append(node_id)
    closed, depots_left = select_supply(current, depots_left, model)
    routes.append(closed)
    return routes


def split_routes(node_id_list, model):
    """Split a visiting order of demand ids into depot-to-depot routes.

    All segment costs are measured from the first depot in ``model.list_depot``.
    For every start position ``i``, the segment ``node_id_list[i..j]`` is
    extended while three limits hold:

    - the vehicle capacity;
    - each node's time-window end;
    - the depot's closing time (with the return trip).

    When ``v[node before segment] + segment cost`` does not exceed the best
    label ``v`` of the segment's last node, that label is updated and ``i - 1``
    is recorded in ``pre``. Segment cost is distance (``opt_type == 0``) or
    travel plus waiting time (otherwise). ``plan_routes`` then builds the
    routes from ``pre``.

    Returns:
        (number_of_routes, route_list)
    """
    supply = model.list_depot[0]
    depot = model.dict_depot[supply]
    dist = model.distance_matrix
    ttime = model.time_matrix
    use_distance = model.opt_type == 0
    v = {_id: float('inf') for _id in model.list_demand}
    v[supply] = 0
    pre = {_id: supply for _id in model.list_demand}
    for i in range(len(node_id_list)):
        # Node preceding the segment; the depot for a segment starting at 0.
        n_4 = node_id_list[i - 1] if i >= 1 else supply
        demand = 0
        departure = 0
        cost = 0
        for j in range(i, len(node_id_list)):
            n_2 = node_id_list[j]
            node = model.dict_demand[n_2]
            demand += node.demand
            if j == i:
                # Single-node segment: depot -> n_2 -> depot.
                arrival = max(node.time_start, depot.time_start + ttime[supply, n_2])
                departure = arrival + node.time_service
                cost = (dist if use_distance else ttime)[supply, n_2] * 2
            else:
                # Append n_2 after n_3: replace leg n_3 -> depot by n_3 -> n_2 -> depot.
                n_3 = node_id_list[j - 1]
                reach = departure + ttime[n_3, n_2]
                arrival = max(reach, node.time_start)
                departure = arrival + node.time_service
                if use_distance:
                    cost = cost - dist[n_3, supply] + dist[n_3, n_2] + dist[n_2, supply]
                else:
                    # Waiting is measured against the unclamped arrival time; the old
                    # term used the already-clamped arrival and was always 0.
                    wait = max(node.time_start - reach, 0)
                    cost = cost - ttime[n_3, supply] + ttime[n_3, n_2] + wait + ttime[n_2, supply]
            if demand > model.vehicle_cap or departure > node.time_end:
                break
            # Vehicle can no longer get back before the depot closes: stop extending.
            # (Previously j was not advanced here, so the same node was re-processed
            # and the loop never ended for zero-demand nodes.)
            if departure + ttime[n_2, supply] > depot.time_end:
                break
            if v[n_4] + cost <= v[n_2]:
                v[n_2] = v[n_4] + cost
                pre[n_2] = i - 1
    route_list = plan_routes(node_id_list, pre, model)
    return len(route_list), route_list


def initial_solution(model):
    """Append ``model.size`` random visiting orders to ``model.list_solution``.

    Each new Solution gets an independently shuffled copy of ``model.list_demand``.
    The global RNG is deliberately not reseeded here. Reseeding it from a handful
    of seeds would limit population diversity and pin the RNG state used by the
    rest of the GA run.
    """
    for _ in range(model.size):
        order = list(model.list_demand)
        random.shuffle(order)
        sol = Solution()
        sol.list_demand = order
        model.list_solution.append(sol)


def select_solution(model):
    """Binary-tournament selection that replaces ``model.list_solution``.

    Works on a deep copy of the population. ``model.n_select`` times, two
    individuals are drawn uniformly with replacement and the one with the
    higher fitness is kept (ties keep the first draw). The same copied
    individual object may be kept more than once.
    """
    pool = copy.deepcopy(model.list_solution)
    winners = model.list_solution = []
    for _ in range(model.n_select):
        first = pool[random.randint(0, len(pool) - 1)]
        second = pool[random.randint(0, len(pool) - 1)]
        winners.append(second if first.fitness < second.fitness else first)


def _order_crossover(keep, fill, start, end):
    """Return a child keeping keep[start:end+1] in place, other genes in *fill*'s order."""
    middle = keep[start:end + 1]
    taken = set(middle)
    rest = [gene for gene in fill if gene not in taken]
    return rest[:start] + middle + rest[start:]


def cross_solution(model):
    """Rebuild ``model.list_solution`` by pairwise order crossover (OX-style variant).

    Two distinct parents are repeatedly drawn from a deep copy of the
    population. With probability ``model.pc``, they are recombined using one
    random segment ``[cro1, cro2]``. Each child keeps that segment from one
    parent and fills the remaining positions, left to right, with the other
    parent's genes in their original order. Otherwise the parents are copied
    unchanged.

    Pairs are added until the population exceeds ``model.size``, so it ends
    with ``size + 1`` or ``size + 2`` individuals. With fewer than two
    individuals no pairing is possible, and the (copied) population is kept as is.
    """
    sol_list = copy.deepcopy(model.list_solution)
    model.list_solution = []
    if len(sol_list) < 2:
        # Two distinct parents are required; the pairing loop would never end.
        model.list_solution = sol_list
        return
    n_genes = len(model.list_demand)
    while len(model.list_solution) <= model.size:
        f1_index = random.randint(0, len(sol_list) - 1)
        f2_index = random.randint(0, len(sol_list) - 1)
        if f1_index == f2_index:
            continue
        f1 = copy.deepcopy(sol_list[f1_index])
        f2 = copy.deepcopy(sol_list[f2_index])
        if random.random() <= model.pc:
            cro1 = random.randint(0, n_genes - 1)
            cro2 = random.randint(cro1, n_genes - 1)
            child1 = _order_crossover(f1.list_demand, f2.list_demand, cro1, cro2)
            child2 = _order_crossover(f2.list_demand, f1.list_demand, cro1, cro2)
            # The chromosome lives in list_demand; the old code wrote the children
            # to an unused `nodes_seq` attribute, so crossover had no effect.
            f1.list_demand = child1
            f2.list_demand = child2
        model.list_solution.append(f1)
        model.list_solution.append(f2)


def mutation_solution(model):
    """Rebuild ``model.list_solution`` by swap mutation.

    Parents are drawn uniformly with replacement from a deep copy of the
    population. Two distinct gene positions are drawn; if they coincide, both
    parent and positions are redrawn. With probability ``model.pm``, the two
    genes are swapped. Individuals are added until the population exceeds
    ``model.size``, so it ends with ``size + 1`` individuals. With fewer than
    two genes, no swap is possible and parents are copied unchanged (this used
    to loop forever).
    """
    sol_list = copy.deepcopy(model.list_solution)
    model.list_solution = []
    n_genes = len(model.list_demand)
    while len(model.list_solution) <= model.size:
        # Already a private copy, so it can be mutated and appended directly.
        f1 = copy.deepcopy(sol_list[random.randint(0, len(sol_list) - 1)])
        if n_genes >= 2:
            m1 = random.randint(0, n_genes - 1)
            m2 = random.randint(0, n_genes - 1)
            if m1 == m2:
                continue  # redraw: swapping a gene with itself is not a mutation
            if random.random() <= model.pm:
                genes = f1.list_demand
                genes[m1], genes[m2] = genes[m2], genes[m1]
        model.list_solution.append(f1)


def plot_value(value_list):
    """Plot the best objective value per iteration (x axis starts at 1) and show it."""
    iterations = list(range(1, len(value_list) + 1))
    plt.plot(iterations, value_list)
    plt.title('迭代结果收敛曲线')
    plt.xlabel('迭代次数')
    plt.ylabel('最优结果')
    plt.grid()
    last_tick = len(value_list) + 1
    plt.xlim(1, last_tick)
    plt.show()


def plot_routes(model):
    """Plot every route of ``model.best_solution`` and show the figure.

    Each route is drawn as a polyline from its start depot, through its demand
    nodes, back to its end depot. Routes starting at depot 'd1' are green,
    'd2' orange, and any other depot blue.
    """
    route_colors = {'d1': 'green', 'd2': 'orange'}
    for route in model.best_solution.list_route:
        start_depot = model.dict_depot[route[0]]
        end_depot = model.dict_depot[route[-1]]
        stops = [model.dict_demand[node_id] for node_id in route[1:-1]]
        x_coord = [start_depot.x] + [n.x for n in stops] + [end_depot.x]
        y_coord = [start_depot.y] + [n.y for n in stops] + [end_depot.y]
        plt.plot(x_coord, y_coord, marker='o', color=route_colors.get(route[0], 'b'),
                 linewidth=0.5, markersize=5)
    # Turn the grid on explicitly, once. plt.grid() with no arguments *toggles*
    # the grid, and calling it once per route left it hidden for an even number
    # of routes.
    plt.grid(True)
    plt.title('运输路径规划图')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()


def put_out(model):
    """Save the best solution to 结果表.xls (sheet 统计) and print where it went.

    Layout:

    - Row 0: headers (time cost, distance cost, objective type, best value).
    - Row 1: the matching values.
    - Row 2: the route-table headers.
    - Row 3 onwards: one row per route, with the vehicle label v1, v2, ...,
      the route ids joined by '-', and that route's timetable tuples joined by '-'.
    """
    best = model.best_solution
    book = Workbook(encoding='utf-8')
    sheet = book.add_sheet('统计')
    for col, title in enumerate(('时间成本', '路程成本', '优化方案', '最优结果')):
        sheet.write(0, col, title)
    plan_name = '最小距离' if model.opt_type == 0 else '最短时间'
    for col, cell in enumerate((best.cost_time, best.cost_distance, plan_name, best.value)):
        sheet.write(1, col, cell)
    for col, title in enumerate(('车辆编号', '运输路径', '时间表')):
        sheet.write(2, col, title)
    for offset, route in enumerate(best.list_route):
        row = offset + 3
        sheet.write(row, 0, f'v{offset + 1}')
        sheet.write(row, 1, '-'.join(map(str, route)))
        sheet.write(row, 2, '-'.join(map(str, best.list_timetable[offset])))
    book.save('结果表.xls')
    print('运算结果保存在 结果表.xls')


def main(demand_csv, supply_csv, pc, pm, size, n_select, opt_type, v_cap, v_speed, epoch):
    """Run the whole GA pipeline and report the result.

    Args:
        demand_csv, supply_csv: paths of the demand-point and depot CSV files.
        pc, pm: crossover and mutation probabilities.
        size: population size.
        n_select: number of individuals kept by tournament selection.
        opt_type: 0 = minimize distance cost, otherwise minimize time cost.
        v_cap, v_speed: vehicle capacity and speed.
        epoch: number of GA iterations.

    Each iteration prints the current best value. The run then plots the
    convergence curve and the best routes, writes 结果表.xls, and finally
    sleeps 10 s.
    """
    model = GA()
    get_data(demand_csv, supply_csv, model)
    model.pc = pc
    model.pm = pm
    model.size = size
    model.n_select = n_select
    model.vehicle_cap = v_cap
    model.vehicle_speed = v_speed
    model.opt_type = opt_type
    matrix_distance_time(model)
    initial_solution(model)
    history_best_value = []
    best_solution = Solution()
    best_solution.value = float('inf')  # sentinel: the first generation's best always replaces it
    model.best_solution = best_solution
    for ep in range(epoch):
        calculate_fitness(model)
        select_solution(model)
        cross_solution(model)
        mutation_solution(model)
        history_best_value.append(model.best_solution.value)
        # Use the `epoch` parameter: the old code read a global `epochs` that only
        # exists when this file is run as a script.
        print(f'迭代次数：{ep + 1}/{epoch} 最优结果：{model.best_solution.value}')
    plot_value(history_best_value)
    plot_routes(model)
    put_out(model)
    time.sleep(10)


if __name__ == '__main__':
    # SimHei is a font with Chinese glyphs, needed for the Chinese plot titles
    # and labels. With axes.unicode_minus off, matplotlib draws minus signs as
    # an ASCII hyphen.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    demand_file, supply_file = '需求点.csv', '供应点.csv'
    epochs = int(input('请输入迭代次数：'))
    opt_types = int(input('请输入优化目标类型：（0：最小距离成本，1：最小时间成本）'))
    ga_settings = {'pc': 0.8, 'pm': 0.1, 'size': 100, 'n_select': 80, 'v_cap': 80, 'v_speed': 1}
    main(demand_csv=demand_file, supply_csv=supply_file, opt_type=opt_types, epoch=epochs,
         **ga_settings)
